import copy
import math
import os
from typing import Union, Tuple, Dict

import numpy as np
import torch
from scipy.sparse import csr_matrix

from controlsnr import find_a_given_snr, solve_ab


def simple_collate_fn(batch):
    """
    Collate a list of graph samples into one dense batch.

    Each sample is a mapping with an 'adj' entry that exposes ``toarray()``
    (e.g. a SciPy sparse matrix) and a 'labels' entry. The per-sample tensors
    are stacked, so every graph in the batch must have the same node count.

    Return:
      - adj: float32 tensor [B, N, N]
      - labels: int64 tensor [B, N]
    """
    def _dense_adj(item):
        # Densify the sparse adjacency before handing it to torch.
        return torch.tensor(item['adj'].toarray(), dtype=torch.float32)

    def _label_row(item):
        return torch.tensor(item['labels'], dtype=torch.long)

    dense_adjs = list(map(_dense_adj, batch))
    label_rows = list(map(_label_row, batch))

    collated = {
        'adj': torch.stack(dense_adjs),       # [B, N, N]
        'labels': torch.stack(label_rows),    # [B, N]
    }
    return collated

def sample_theta(mode, N, rng, theta_kwargs=None):
    """
    Sample the per-node degree-correction coefficients θ.

    Distribution choice:
      - mode 'train' or 'val': a single rng.random() draw selects all-ones
        (plain SBM, θ=1) when it is < 0.5, otherwise a Gamma draw (DCBM).
      - any other mode (e.g. 'test'): always all-ones; rng is not consumed.

    Parameters (optional, read from theta_kwargs):
      - gamma_shape_mode: 'random' | 'fixed', default 'random'
      - gamma_shape_range: (low, high) range of κ, drawn log-uniformly, default (1.5, 3.0)
      - gamma_shape: κ used when gamma_shape_mode='fixed', default 1.0
      - clip_quantile: if strictly between 0 and 1, cap θ at that empirical quantile
        (e.g. 0.999); default None (no capping)
      - clip_max: if given, cap θ at this absolute value (applied after clip_quantile)

    The Gamma draw uses scale = 1/κ, so its expectation is 1; the sample itself is
    not rescaled here (no 'normalize_mean' option is read by this function).

    Returns:
      float ndarray of shape (N,)

    Raises:
      ValueError: for an unknown gamma_shape_mode (only reached on the Gamma path).
    """
    opts = {} if theta_kwargs is None else theta_kwargs

    # The coin is only flipped for train/val, so other modes leave rng untouched.
    use_gamma = mode in ("train", "val") and not (rng.random() < 0.5)
    if not use_gamma:
        return np.ones(N, dtype=float)

    # === Gamma(kappa, scale=1/kappa) -> mean 1 ===
    shape_mode = opts.get("gamma_shape_mode", "random")
    if shape_mode == "fixed":
        kappa = float(opts.get("gamma_shape", 1.0))
    elif shape_mode == "random":
        low, high = opts.get("gamma_shape_range", (1.5, 3.0))
        # Log-uniform draw so narrow and wide heterogeneity are covered evenly
        kappa = float(np.exp(rng.uniform(np.log(low), np.log(high))))
    else:
        raise ValueError(f"Unsupported gamma_shape_mode: {shape_mode}")

    theta = rng.gamma(shape=kappa, scale=1.0 / kappa, size=N).astype(float)

    # Optional: trim extremely large hubs
    quantile = opts.get("clip_quantile", None)
    ceiling = opts.get("clip_max", None)
    if quantile is not None:
        level = float(quantile)
        if 0.0 < level < 1.0:
            theta = np.minimum(theta, float(np.quantile(theta, level)))
    if ceiling is not None:
        theta = np.minimum(theta, float(ceiling))

    return theta

# ---- Config ----
# Module-level grids consumed by Generator.prepare_data.
# Number of graphs generated per grid cell for the training / validation sets.
per_cell_tr = 4
per_cell_v = 1

# Training SNR targets: 15 points, log-spaced between 0.5 and 3.
snr_train = np.logspace(np.log10(0.5), np.log10(3), 15)
# gamma: 5 fixed points
gamma_train = np.array([0.30, 1.20, 3.00, 4.00 ,5.00])

# Validation set: geometric midpoints of neighbouring training SNR values
snr_mid = np.sqrt(snr_train[:-1] * snr_train[1:])
# Pick 10 of the midpoints, evenly spaced by index
idx = np.linspace(0, len(snr_mid) - 1, 10, dtype=int)
snr_val = snr_mid[idx]
gamma_val = gamma_train.copy()

# Optional: Quickly check that the validation point does not coincide with the training point
assert set(snr_val).isdisjoint(snr_train)
# assert set(C_val).isdisjoint(C_train)

# Test set: a single SNR / gamma point (the original note reads "=200"; its meaning is not evident here)
snr_test = (0.60,)
gamma_test = (0.15,)
# C_test = (10.0,)
per_cell_te = 1

def _normalize_theta_global(theta: np.ndarray) -> np.ndarray:
    """
    Rescale θ over all nodes so that its mean becomes 1 (sum(theta) == N).

    The input is converted to a float array. When the sum is not strictly
    positive (all zeros, negative total, NaN) the converted array is returned
    without rescaling.
    """
    values = np.asarray(theta, dtype=float)
    total = values.sum()
    count = len(values)
    if not total > 0:
        # Nothing sensible to scale by; hand back the values untouched.
        return values
    return values * (count / total)

class Generator(object):
    def __init__(self, N_train=50, N_test=100, N_val = 50,generative_model='SBM_multiclass', p_SBM=0.8, q_SBM=0.2, n_classes=2, path_dataset='dataset',
                 num_examples_train=100, num_examples_test=10, num_examples_val=10):
        self.N_train = N_train
        self.N_test = N_test
        self.N_val = N_val

        self.generative_model = generative_model
        self.p_SBM = p_SBM
        self.q_SBM = q_SBM
        self.n_classes = n_classes
        self.path_dataset = path_dataset

        self.data_train = None
        self.data_test = None
        self.data_val = None

        self.num_examples_train = num_examples_train
        self.num_examples_test = num_examples_test
        self.num_examples_val = num_examples_val

        # 初始化时直接生成
        self.C_train = self._make_C_grid(self.n_classes)
        # Validation_set
        c_mid = np.sqrt(self.C_train[:-1] * self.C_train[1:])
        # 从中均匀挑 10 个
        idx = np.linspace(0, len(c_mid) - 1, 10, dtype=int)
        self.C_val = c_mid[idx]
        self.C_test = (10,)  # 你可以换成别的逻辑

    def compute_C_bounds(self,k, margin=0.0):
        """
        Return (C_min, C_max) for k classes.

        C_min = k * 3 * (1 + margin), where 3 is the largest SNR considered;
        C_max lies 6 * k above C_min.
        """
        peak_snr = 3
        # Lower bound
        lower = k * peak_snr * (1.0 + float(margin))
        return lower, lower + k * 6

    def _make_C_grid(self,k):
        """生成 log-uniform 网格"""
        C_min, C_max = self.compute_C_bounds(k)
        if C_max <= C_min:
            return [C_min]
        return np.exp(np.linspace(np.log(C_min), np.log(C_max), 15))


    def SBM(self, p, q, N):
        W = np.zeros((N, N))

        p_prime = 1 - np.sqrt(1 - p)
        q_prime = 1 - np.sqrt(1 - q)

        n = N // 2

        W[:n, :n] = np.random.binomial(1, p, (n, n))
        W[n:, n:] = np.random.binomial(1, p, (N-n, N-n))
        W[:n, n:] = np.random.binomial(1, q, (n, N-n))
        W[n:, :n] = np.random.binomial(1, q, (N-n, n))
        W = W * (np.ones(N) - np.eye(N))
        W = np.maximum(W, W.transpose())

        perm = torch.randperm(N).numpy()
        blockA = perm < n
        labels = blockA * 2 - 1

        W_permed = W[perm]
        W_permed = W_permed[:, perm]
        return W_permed, labels


    def SBM_multiclass(self, p, q, N, n_classes):

        p_prime = 1 - np.sqrt(1 - p)
        q_prime = 1 - np.sqrt(1 - q)

        prob_mat = np.ones((N, N)) * q_prime

        n = N // n_classes  # 基础类别大小
        remainder = N % n_classes  # 不能整除的剩余部分
        n_last = n + remainder  # 最后一类的大小

        # 先对整除部分进行块状分配
        for i in range(n_classes - 1):  # 处理前 n_classes-1 类
            prob_mat[i * n: (i + 1) * n, i * n: (i + 1) * n] = p_prime

        # 处理最后一类
        start_idx = (n_classes - 1) * n  # 最后一类的起始索引
        prob_mat[start_idx: start_idx + n_last, start_idx: start_idx + n_last] = p_prime

        # 生成邻接矩阵
        W = np.random.rand(N, N) < prob_mat
        W = W.astype(int)

        W = W * (np.ones(N) - np.eye(N))  # 移除自环
        W = np.maximum(W, W.transpose())  # 确保无向图

        # 随机打乱节点顺序
        perm = torch.randperm(N).numpy()

        # 生成类别标签
        labels =np.minimum((perm // n) , n_classes - 1)

        W_permed = W[perm]
        W_permed = W_permed[:, perm]

        #计算P矩阵的特征向量
        prob_mat_permed = prob_mat[perm][:, perm]
        # np.fill_diagonal(prob_mat_permed, 0)  # 去除自环

        eigvals, eigvecs = np.linalg.eigh(prob_mat_permed)
        idx = np.argsort(eigvals)[::-1]
        eigvecs_top = eigvecs[:, idx[:n_classes]]

        return W_permed, labels, eigvecs_top  # 返回前n_classes特征向量

    def imbalanced_SBM_multiclass(self, p, q, N, n_classes, class_sizes):

        # 上三角采样不会放大概率，直接用目标 p, q
        p_prime = float(p)
        q_prime = float(q)

        # 构造期望矩阵（块内 p，块间 q），无自环
        prob_mat = np.full((N, N), q_prime, dtype=float)
        boundaries = np.cumsum([0] + class_sizes)
        for i in range(n_classes):
            start, end = boundaries[i], boundaries[i + 1]
            prob_mat[start:end, start:end] = p_prime
        np.fill_diagonal(prob_mat, 0.0)

        # —— 关键修改：只采样上三角，然后镜像 —— #
        W = np.zeros((N, N), dtype=np.uint8)
        iu, ju = np.triu_indices(N, k=1)
        W[iu, ju] = (np.random.rand(iu.size) < prob_mat[iu, ju]).astype(np.uint8)
        W = (W + W.T).astype(np.uint8)  # 无向化；对角仍为 0

        # 打乱节点顺序
        perm = torch.randperm(N).numpy()

        # 生成并置乱标签
        labels = np.zeros(N, dtype=int)
        for i in range(n_classes):
            start, end = boundaries[i], boundaries[i + 1]
            start, end = boundaries[i], boundaries[i + 1]
            labels[start:end] = i
        labels = labels[perm]

        # 同步置乱矩阵
        W_permed = W[perm][:, perm]

        # 置乱后的期望矩阵用于特征分解（与 W_permed 对齐）
        prob_mat_permed = prob_mat[perm][:, perm]
        eigvals, eigvecs = np.linalg.eigh(prob_mat_permed)
        idx = np.argsort(eigvals)[::-1]
        eigvecs_top = eigvecs[:, idx[:n_classes]]

        return W_permed, labels, eigvecs_top

    def imbalanced_DCSBM_multiclass(self, B_prob, labels, theta, *, rng=None, return_eigvecs=False, topk=8):
        """
        返回: W_dense(or sparse), labels, (可选) eigvecs_top
        """
        import numpy as np
        rng = np.random.default_rng() if rng is None else rng

        N = len(labels)
        k = B_prob.shape[0]
        labels = np.asarray(labels, dtype=int)
        theta = np.asarray(theta, dtype=float)

        # 用广播构造 dense 概率矩阵（简洁版）
        # 先构一个 (N, N) 的块 B_gij：通过行列映射索引
        B_rows = B_prob[labels]  # (N, k)
        B_gij = B_rows[:, labels]  # (N, N) —— 等价 Z B Z^T 的取值

        P = (theta[:, None] * theta[None, :]) * B_gij
        np.fill_diagonal(P, 0.0)
        # 数值安全：clip 到 [0,1]
        P = np.clip(P, 0.0, 1.0)

        # 采样
        A = (rng.random((N, N)) < P).astype(np.float32)
        # 无向图对称化
        A = np.triu(A, 1)
        A = A + A.T

        eigvecs_top = None
        if return_eigvecs:
            # 简单返回前 topk 个特征向量（可替换为 Bethe Hessian / 归一化拉普拉斯）
            w, v = np.linalg.eigh(A)
            idx = np.argsort(w)[::-1][:topk]
            eigvecs_top = v[:, idx].astype(np.float32, copy=False)

        return A, labels, eigvecs_top


    def prepare_data(self):
        """
        Locate the train / val / test datasets on disk, generating any that are missing.

        For each split a directory below self.path_dataset is derived from the
        generative model name, class count, graph size and example count. If the
        directory holds no .npz files, create_dataset_grid_dcsbm is called to fill it.
        Afterwards self.data_train / self.data_val / self.data_test hold the sorted
        lists of .npz file paths.

        The SNR / gamma grids and per-cell counts come from the module-level
        configuration; the C grids come from self.C_train / self.C_val / self.C_test.
        """
        def get_npz_dataset(path, mode, *, snr_grid, gamma_grid, C_grid, per_cell, min_size=50, base_seed=0):
            """Return the sorted .npz paths under `path`, generating the split first if it is empty."""
            if not os.path.exists(path):
                # exist_ok: tolerate another process creating the directory in between
                os.makedirs(path, exist_ok=True)
                print(f"[创建数据集] {mode} 数据目录不存在，已新建：{path}")

            npz_files = sorted([f for f in os.listdir(path) if f.endswith(".npz")])
            if not npz_files:
                print(f"[创建数据集] {mode} 数据未找到，开始生成...")
                self.create_dataset_grid_dcsbm(
                    path, mode=mode,
                    snr_grid=snr_grid,
                    gamma_grid=gamma_grid,
                    C_grid=C_grid,
                    per_cell=per_cell,
                    min_size=min_size,
                    base_seed=base_seed
                )
                npz_files = sorted([f for f in os.listdir(path) if f.endswith(".npz")])
            else:
                print(f"[读取数据] {mode} 集已存在，共 {len(npz_files)} 张图：{path}")
            return [os.path.join(path, f) for f in npz_files]

        # ==== Directories ====
        train_dir = f"{self.generative_model}_nc{self.n_classes}_rand_gstr{self.N_train}_numtr{self.num_examples_train}"
        test_dir = f"{self.generative_model}_nc{self.n_classes}_rand_gste{self.N_test}_numte{self.num_examples_test}"
        val_dir = f"{self.generative_model}_nc{self.n_classes}_rand_val{self.N_val}_numval{self.num_examples_val}"

        train_path = os.path.join(self.path_dataset, train_dir)
        test_path = os.path.join(self.path_dataset, test_dir)
        val_path = os.path.join(self.path_dataset, val_dir)

        # ==== One parameter set per split ====
        self.data_train = get_npz_dataset(
            train_path, 'train',
            snr_grid=snr_train, gamma_grid=gamma_train, C_grid=self.C_train, per_cell=per_cell_tr,
            min_size=10, base_seed=123
        )
        self.data_val = get_npz_dataset(
            val_path, 'val',
            snr_grid=snr_val, gamma_grid=gamma_val, C_grid=self.C_val, per_cell=per_cell_v,
            min_size=10, base_seed=2025
        )
        # The test split must sweep the C grid (self.C_test), not the SNR grid:
        # passing snr_test here would use an SNR value as the total-degree constant C.
        self.data_test = get_npz_dataset(
            test_path, 'test',
            snr_grid=snr_test, gamma_grid=gamma_test, C_grid=self.C_test, per_cell=per_cell_te,
            min_size=10, base_seed=31415
        )


    def sample_single(self, i, is_training=True):
        if is_training:
            dataset = self.data_train
        else:
            dataset = self.data_test
        example = dataset[i]
        if (self.generative_model == 'SBM_multiclass'):
            W_np = example['W']
            labels = np.expand_dims(example['labels'], 0)
            labels_var = torch.from_numpy(labels)
            if is_training:
                labels_var.requires_grad = True
            return W_np, labels_var


    def sample_otf_single(self, is_training=True, cuda=True):
        """
        Generate one graph on the fly from the configured generative model.

        Parameters:
          - is_training: use self.N_train nodes when True, self.N_test otherwise
          - cuda: accepted for interface compatibility; not used here

        Returns:
          - W: adjacency with a leading batch axis, shape (1, N, N) (NumPy array)
          - labels: tensor of shape (1, N)
          - eigvecs_top: leading eigenvectors from SBM_multiclass, or None for the
            plain 'SBM' model, which does not compute any

        Raises:
          ValueError: if self.generative_model is neither 'SBM' nor 'SBM_multiclass'.
        """
        N = self.N_train if is_training else self.N_test

        if self.generative_model == 'SBM':
            W, labels = self.SBM(self.p_SBM, self.q_SBM, N)
            # SBM() returns no eigenvectors; keep the 3-tuple shape for callers
            eigvecs_top = None
        elif self.generative_model == 'SBM_multiclass':
            W, labels, eigvecs_top = self.SBM_multiclass(self.p_SBM, self.q_SBM, N, self.n_classes)
        else:
            raise ValueError('Generative model {} not supported'.format(self.generative_model))

        # Add the batch axis; W deliberately stays a NumPy array (no requires_grad)
        labels = torch.from_numpy(np.expand_dims(labels, 0))
        W = np.expand_dims(W, 0)

        return W, labels, eigvecs_top

    def imbalanced_sample_otf_single(self, class_sizes , is_training=True, cuda=True):
        if is_training:
            N = self.N_train
        else:
            N = self.N_test
        if self.generative_model == 'SBM':
            W, labels = self.SBM(self.p_SBM, self.q_SBM, N)
        elif self.generative_model == 'SBM_multiclass':
            W, labels,eigvecs_top = self.imbalanced_SBM_multiclass(self.p_SBM, self.q_SBM, N, self.n_classes, class_sizes)
        else:
            raise ValueError('Generative model {} not supported'.format(self.generative_model))

        labels = np.expand_dims(labels, 0)
        labels = torch.from_numpy(labels)
        W = np.expand_dims(W, 0)
        # W = torch.tensor(W, dtype=torch.float32)  # 不加 requires_grad

        return W, labels, eigvecs_top


    def random_sample_otf_single(self, C = 10 ,is_training=True, cuda=True):
        """
        Generate one graph on the fly with randomly drawn SBM parameters.

        For 'SBM_multiclass', the a/b ratio range is bracketed by the solutions of
        find_a_given_snr for SNR 0.1 and SNR 1 at the given C; p, q and the class
        sizes are then drawn by random_imbalanced_SBM_generator_balanced_sampling
        (with a minimum class size of 20, so N must be at least 20 * n_classes).
        For plain 'SBM', the fixed self.p_SBM / self.q_SBM are used.

        Parameters:
          - C: total-degree constant passed to the parameter sampler
          - is_training: use self.N_train nodes when True, self.N_test otherwise
          - cuda: accepted for interface compatibility; not used here

        Returns:
          (W, labels, eigvecs_top, snr, class_sizes) where W has shape (1, N, N) and
          labels is a (1, N) tensor. For the plain 'SBM' model eigvecs_top, snr and
          class_sizes are None because that model does not produce them.

        Raises:
          ValueError: if self.generative_model is neither 'SBM' nor 'SBM_multiclass'.
        """
        N = self.N_train if is_training else self.N_test

        if self.generative_model == 'SBM':
            W, labels = self.SBM(self.p_SBM, self.q_SBM, N)
            # Nothing random is drawn for this model; keep the 5-tuple shape
            eigvecs_top = None
            snr = None
            class_sizes = None

        elif self.generative_model == 'SBM_multiclass':
            a_low, b_low = find_a_given_snr(0.1, self.n_classes, C)
            a_high, b_high = find_a_given_snr(1, self.n_classes, C)

            lower_bound = a_low / b_low
            upper_bound = a_high / b_high

            # Make sure the range is ordered before sampling from it
            if lower_bound > upper_bound:
                lower_bound, upper_bound = upper_bound, lower_bound

            p, q, class_sizes, snr = self.random_imbalanced_SBM_generator_balanced_sampling(
                N=N,
                n_classes=self.n_classes,
                C=C,
                alpha_range=(lower_bound, upper_bound),
                min_size= 20
            )
            W, labels, eigvecs_top = self.imbalanced_SBM_multiclass(p, q, N, self.n_classes, class_sizes)

        else:
            raise ValueError('Generative model {} not supported'.format(self.generative_model))

        # Add the batch axis; W deliberately stays a NumPy array (no requires_grad)
        labels = torch.from_numpy(np.expand_dims(labels, 0))
        W = np.expand_dims(W, 0)

        return W, labels, eigvecs_top, snr, class_sizes


    def random_imbalanced_SBM_generator_balanced_sampling(self, N, n_classes, C, *,
                                        alpha_range=(1.3, 2.8),
                                        min_size=5):
        """
        随机生成 SBM 模型的参数，社区大小为随机比例但总和为 N。
        返回 p, q, class_sizes, a, b, snr。
        """
        assert N >= min_size * n_classes

        # Step 1: 随机生成 a > b，使得 a + (k - 1) * b = C
        alpha = np.random.uniform(*alpha_range)
        b = C / (alpha + (n_classes - 1))
        a = alpha * b

        # Step 2: 计算边连接概率
        logn = np.log(N)
        p = a * logn / N
        q = b * logn / N

        # ✅ Step 3: 使用 Dirichlet 生成 class_sizes
        remaining = N - min_size * n_classes
        probs = np.random.dirichlet(np.ones(n_classes))  # 总和为1的概率向量
        extras = np.random.multinomial(remaining, probs)
        class_sizes = [min_size + e for e in extras]

        # Step 4: 计算 SNR
        snr = (a - b) ** 2 / (n_classes * (a + (n_classes - 1) * b))

        return p, q, class_sizes, snr

    def _sample_class_sizes_dirichlet(
            self,
            N: int,
            n_classes: int,
            gamma: float,
            min_size: int,
            rng: Union[int, np.random.Generator],
            gamma_jitter: float = 0.5,
            return_labels: bool = False,  # 新增：是否返回逐点 labels
            shuffle_labels: bool = True,  # 新增：是否打乱节点顺序
            eps: float = 1e-12  # 数值下界，避免 gamma_used 过小
    ) -> Union[list, Tuple[list, np.ndarray, Dict]]:

        # --- RNG 统一化 ---
        if isinstance(rng, (int, np.integer)):
            rng = np.random.default_rng(int(rng))

        assert N >= min_size * n_classes, "N 必须 >= min_size * n_classes"
        remaining = N - min_size * n_classes

        # --- gamma 抖动并做下界裁剪 ---
        if gamma_jitter and gamma_jitter > 0:
            mult = rng.uniform(max(0.0, 1.0 - gamma_jitter), 1.0 + gamma_jitter)
            gamma_used = max(eps, float(gamma) * float(mult))
        else:
            gamma_used = max(eps, float(gamma))

        alpha = np.full(n_classes, gamma_used, dtype=float)

        # --- 采样类别比例（若 remaining=0，给个均匀兜底） ---
        probs = rng.dirichlet(alpha) if remaining > 0 else np.full(n_classes, 1.0 / n_classes)

        # --- 分配剩余名额 ---
        extras = rng.multinomial(remaining, probs) if remaining > 0 else np.zeros(n_classes, dtype=int)
        sizes = (min_size + extras).astype(int).tolist()

        if not return_labels:
            return sizes  # 与旧代码保持一致

        # --- 展开成逐点标签 (N,) ---
        labels = np.concatenate([np.full(sz, c, dtype=int) for c, sz in enumerate(sizes)])
        if shuffle_labels:
            labels = labels[rng.permutation(N)]

        meta = dict(gamma_used=gamma_used, probs=probs, sizes=np.array(sizes, dtype=int))
        return sizes, labels, meta


    def gen_one_dcsbm_by_targets(
            self, N, n_classes, C, mode,target_snr, gamma, min_size=5, *, rng=None,
            heterophily=False, hetero_prob=None,
            # jitter and constraints
            ab_jitter=0.05, keep_assortativity=True,  # keep the assortative / disassortative direction
            pq_jitter=(0.02, 0.05),  # (p_jitter, q_jitter)
            C_jitter=0.1, C_jitter_mode='relative',
            a_floor=1e-8, b_floor=1e-8,
            # theta distribution
            theta_dist='pareto', theta_kwargs=None,
            normalize_theta=True
    ):
        """
        Build the parameters of one DCSBM graph from targets (C, target_snr, gamma).

        The constraint a + (k - 1) * b = C is not enforced after solving; (a, b)
        only receive a small multiplicative jitter.

        Parameters:
          - N, n_classes: graph size and number of classes
          - C: total-degree constant handed to find_a_given_snr (the un-jittered value)
          - mode: split name, forwarded to sample_theta
          - target_snr: SNR handed to find_a_given_snr
          - gamma: Dirichlet concentration for the class sizes
          - min_size: minimum nodes per class
          - rng: NumPy Generator for sizes and jitters (fresh unseeded one when None)
          - heterophily / hetero_prob: make the graph disassortative (b > a), either
            always or with probability hetero_prob
          - ab_jitter: relative jitter on a and b
          - keep_assortativity: after jitter, restore a >= b (or b >= a if heterophilous)
          - pq_jitter: (diag, off-diag) relative jitter on the block matrix, or None
          - C_jitter, C_jitter_mode: jitter applied to the reported C_used only
          - a_floor, b_floor: lower bounds for a and b
          - theta_dist: accepted for interface compatibility; the distribution is
            chosen by sample_theta from `mode`
          - theta_kwargs: overrides merged on top of the default sample_theta options
          - normalize_theta: rescale theta to mean 1

        Returns:
          B_prob, labels, theta, a, b, gamma, is_hetero, C_used

        Note: theta is drawn from a fresh unseeded generator on purpose, so it
        differs between runs even when `rng` is seeded.
        """
        rng = np.random.default_rng() if rng is None else rng

        # === 0) Optional small jitter on C (only affects the reported C_used) ===
        C_used = float(C)
        if C_jitter and C_jitter > 0:
            if C_jitter_mode == 'relative':
                C_used *= rng.uniform(1.0 - float(C_jitter), 1.0 + float(C_jitter))
            elif C_jitter_mode == 'absolute':
                C_used += rng.uniform(-float(C_jitter), float(C_jitter))
            C_used = max(C_used, 1e-6)

        # === 1) Class sizes and labels ===
        _sizes, labels, _meta = self._sample_class_sizes_dirichlet(
            N=N, n_classes=n_classes, gamma=gamma, min_size=min_size, rng=rng,
            return_labels=True, shuffle_labels=True
        )

        # === 2) Solve (a0, b0) from (target_snr, original C, k) ===
        a0, b0 = find_a_given_snr(target_snr, n_classes, total_ab = C)
        # Assortative / disassortative switch
        if hetero_prob is not None:
            is_hetero = bool(rng.random() < float(hetero_prob))
        else:
            is_hetero = bool(heterophily)
        if is_hetero:
            a0, b0 = b0, a0  # disassortative: b > a

        # === 3) Small jitter on (a, b); a + (k - 1) * b = C is no longer enforced ===
        a = float(a0)
        b = float(b0)
        if ab_jitter and ab_jitter > 0:
            j = float(ab_jitter)
            a *= rng.uniform(1.0 - j, 1.0 + j)
            b *= rng.uniform(1.0 - j, 1.0 + j)

        # Floors and direction constraint
        a = max(a, a_floor)
        b = max(b, b_floor)
        if keep_assortativity:
            if is_hetero:
                # disassortative: b should be >= a
                if b < a:
                    b = max(a * 1.001, b_floor)
            else:
                # assortative: a should be >= b
                if a < b:
                    a = max(b * 1.001, a_floor)

        # === 4) Degree-correction weights ===
        # Defaults first, then caller overrides, so theta_kwargs is honoured.
        theta_opts = dict(
            gamma_shape_mode="random",
            gamma_shape_range=(1.5, 3.0),  # range of the shape κ (smaller = heavier tail)
            clip_quantile=0.999,  # mild trimming of extreme hubs
        )
        if theta_kwargs:
            theta_opts.update(theta_kwargs)
        theta = sample_theta(
            mode=mode,
            N=N,
            rng=np.random.default_rng(),  # unseeded on purpose: differs on every call
            theta_kwargs=theta_opts,
        )

        if normalize_theta:
            theta = _normalize_theta_global(theta)

        # === 5) Block probability matrix on the log(n) / n scale ===
        logn = np.log(N)
        scale = logn / N
        B_prob = np.full((n_classes, n_classes), b * scale, dtype=float)
        np.fill_diagonal(B_prob, a * scale)

        if pq_jitter is not None:
            pj, qj = pq_jitter
            idx_d = np.eye(n_classes, dtype=bool)
            B_prob[idx_d] *= rng.uniform(1.0 - pj, 1.0 + pj, size=idx_d.sum())

            idx_od = ~np.eye(n_classes, dtype=bool)
            B_prob[idx_od] *= rng.uniform(1.0 - qj, 1.0 + qj, size=idx_od.sum())

        # Symmetrise (undirected graph)
        B_prob = 0.5 * (B_prob + B_prob.T)

        # Enforce diagonal > off-diagonal, but only for assortative graphs:
        # applying it to a heterophilous draw would undo the a/b swap above.
        if not is_hetero:
            diag_vals = np.diag(B_prob).copy()
            for r in range(n_classes):
                for s in range(n_classes):
                    if r != s and B_prob[r, s] >= diag_vals[r]:
                        B_prob[r, s] = max(1e-12, diag_vals[r] * 0.9)  # push below the diagonal
            # The row-wise clamp can treat (r, s) and (s, r) differently; taking
            # the smaller value restores symmetry and keeps both rows below
            # their diagonal entries.
            B_prob = np.minimum(B_prob, B_prob.T)

        # Return (snr is not included)
        return B_prob, labels, theta, a, b, gamma, is_hetero, C_used


    def create_dataset_grid_dcsbm(self, directory, mode='train', *,
                                  snr_grid=(0.6, 0.9, 1.1, 1.3, 1.6, 2.0, 2.5, 3.0),
                                  gamma_grid=(0.15, 0.3, 0.6, 1.0, 2.0),
                                  C_grid=(10.0,),
                                  per_cell=20,
                                  min_size=5,
                                  base_seed=0,
                                  # extra DCSBM controls
                                  theta_dist='pareto', theta_kwargs=None,
                                  normalize_theta=True,
                                  return_eigvecs=False, topk=8,
                                  size_jitter=500):
        """
        Generate DCSBM graphs on an (SNR x gamma x C) grid, per_cell graphs per cell,
        and write each one to its own compressed .npz file in `directory`.

        Parameters:
          - directory: output directory (created if missing)
          - mode: 'train' | 'val' | 'test'; selects the base graph size
            (self.N_train / self.N_val / self.N_test) and which self.data_* slot
            is set to `directory`
          - snr_grid, gamma_grid, C_grid: grid axes
          - per_cell: graphs per grid cell
          - min_size: minimum nodes per class
          - base_seed: base of the per-cell seed for the sampling generator
          - theta_dist, theta_kwargs, normalize_theta: forwarded to gen_one_dcsbm_by_targets
          - return_eigvecs, topk: forwarded to imbalanced_DCSBM_multiclass; when
            eigenvectors are computed they are stored under 'eigvecs_top'
          - size_jitter: half-width of the uniform perturbation of the graph size

        Each graph's size is the base size plus a uniform offset in
        (-size_jitter, +size_jitter), drawn from NumPy's global RNG (so not
        controlled by base_seed), and never smaller than min_size * n_classes.

        Raises:
          ValueError: for an unsupported mode.
        """
        os.makedirs(directory, exist_ok=True)

        if mode == 'train':
            N = self.N_train
            self.data_train = directory
        elif mode == 'val':
            N = self.N_val
            self.data_val = directory
        elif mode == 'test':
            N = self.N_test
            self.data_test = directory
        else:
            raise ValueError(f"Unsupported mode: {mode}")

        num_graphs_expected = len(snr_grid) * len(gamma_grid) * len(C_grid) * per_cell

        # Smallest graph the class-size sampler accepts; without this floor the
        # size perturbation can go below min_size * n_classes (or even negative)
        # whenever the base size is smaller than size_jitter.
        n_floor = max(min_size * self.n_classes, 2)

        idx = 0
        for c_idx, C in enumerate(C_grid):
            for s_idx, snr_target in enumerate(snr_grid):
                for g_idx, gamma in enumerate(gamma_grid):
                    cell_seed = base_seed + (c_idx * 10_000_000 + s_idx * 10_000 + g_idx * 100)
                    rng = np.random.default_rng(cell_seed)

                    for rep in range(per_cell):
                        # === Draw the DCSBM parameters for this graph ===
                        rand_N = int(N + (np.random.rand() * 2 - 1) * size_jitter)
                        rand_N = max(rand_N, n_floor)
                        (B_prob, labels, theta,
                        a, b, gamma_val, is_hetero, C_used) = self.gen_one_dcsbm_by_targets(
                            N=rand_N, n_classes=self.n_classes, C=C, mode = mode,
                            target_snr=snr_target, gamma=gamma,
                            min_size=min_size, rng=rng,
                            ab_jitter=0.05, keep_assortativity=True,
                            pq_jitter=(0.02, 0.05), C_jitter=0.1, C_jitter_mode='relative',
                            b_floor=1e-6,
                            # theta controls
                            theta_dist=theta_dist, theta_kwargs=theta_kwargs,
                            normalize_theta=normalize_theta
                        )

                        # === Sample the DCSBM graph ===
                        W_dense, labels_out, eigvecs_top = self.imbalanced_DCSBM_multiclass(
                            B_prob=B_prob, labels=labels, theta=theta,
                            rng=rng, return_eigvecs=return_eigvecs, topk=topk
                        )

                        W_sparse = csr_matrix(W_dense)

                        fname = (f"{mode}_N{rand_N}_i{idx:05d}"
                                 f"__C{C:.2f}__snr{snr_target:.3f}"
                                 f"__g{gamma:.3f}__rep{rep:02d}.npz")
                        path = os.path.join(directory, fname)

                        # === Payload: CSR adjacency parts plus the DCSBM metadata ===
                        payload = dict(
                            adj_data=W_sparse.data,
                            adj_indices=W_sparse.indices,
                            adj_indptr=W_sparse.indptr,
                            adj_shape=W_sparse.shape,
                            labels=labels_out.astype(np.int32),
                            # a, b, C and the SNR target
                            a=a, b=b, C=C, C_used=C_used,
                            snr_target=snr_target,
                            gamma=gamma_val,
                            # DCSBM essentials
                            theta=theta.astype(np.float32),
                            B_prob=B_prob.astype(np.float32),
                        )
                        # Optional eigenvectors: only stored when they were requested
                        if eigvecs_top is not None:
                            payload['eigvecs_top'] = eigvecs_top

                        np.savez_compressed(path, **payload)
                        idx += 1

        print(f"[{mode}] (DCSBM) 网格数据完成: 共 {idx} 张（期望 {num_graphs_expected}）。目录: {directory}")

    def copy(self):
        return copy.deepcopy(self)